Perception Encoder Integration #2478
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
timm/models/pe.py
Outdated
```python
elif freqs_for == "lang":
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
elif freqs_for == "pixel":
    freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
```
`pi` here should load from torch.
Also, prefer to keep torch.pi vs math.pi, and don't import x.pi as pi ...
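A minimal sketch of the preferred style (dim and max_freq below are illustrative stand-ins, not the PR's actual defaults):

```python
import torch

dim, max_freq = 64, 224  # illustrative values only

# Reference torch.pi directly instead of importing pi under a bare name.
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * torch.pi
```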
```python
elif freqs_for == "constant":
    freqs = torch.ones(num_freqs).float()

self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)
```
The freqs here is a parameter that isn't in the original model, so there are complaints about this when loading the state dict... I assume the behaviour in the pretrained model still matches the current code? But for the option of having learned_freq, should this be...
```python
theta *= theta_rescale_factor ** (dim / (dim - 2))
if freqs_for == "lang":
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
elif freqs_for == "pixel":
    freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
elif freqs_for == "constant":
    freqs = torch.ones(num_freqs).float()
else:
    assert False
if learned_freq:
    self.freqs = nn.Parameter(freqs)
else:
    self.freqs = nn.Buffer(freqs, persistent=False)
```
```python
attn_pooler_heads: int = 8,
pool_type: Literal["attn", "tok", "avg", "none"] = "attn",
num_classes: int = 0,  # no use for PE
in_chans: int = 3,
```
I do need to add support for a classifier, either in the PE module or by wrapping everything, otherwise the default behaviour for adapting encoders as classifiers doesn't work so well ... I'll figure out how best to support it.
Added classifier support (and reset) in the new commit. The current forward pass is [x -> Transformer(x)] -> [pool -> proj -> head (for classification)], implemented as forward_features and forward_head respectively. Let's discuss more in Slack (hf-fair-pe-collab). Thank you!
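For readers unfamiliar with the timm convention, a minimal sketch of that split (the class and argument names below are invented for illustration; the real implementation lives in timm/models/pe.py):

```python
import torch
import torch.nn as nn


class PEClassifierSketch(nn.Module):
    """Illustrative only: the pool -> proj -> head ordering described above."""

    def __init__(self, encoder: nn.Module, embed_dim: int, proj_dim: int, num_classes: int):
        super().__init__()
        self.encoder = encoder  # Transformer trunk
        self.proj = nn.Linear(embed_dim, proj_dim)  # CLIP-style projection
        self.head = nn.Linear(proj_dim, num_classes) if num_classes else nn.Identity()

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)  # x -> Transformer(x), tokens of shape (B, N, C)

    def forward_head(self, x: torch.Tensor) -> torch.Tensor:
        x = x.mean(dim=1)  # pool (avg shown here; PE defaults to attention pooling)
        x = self.proj(x)  # proj
        return self.head(x)  # head (identity when num_classes == 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_head(self.forward_features(x))
```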
```python
self.conv1 = nn.Conv2d(
    in_channels=3,
    out_channels=width,
```
3 should be in_chans
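As a sketch of the fix (the kernel_size/stride/bias values are assumptions based on a typical ViT patch embed, not copied from the PR):

```python
import torch.nn as nn

in_chans, width, patch_size = 3, 1024, 14  # illustrative values

conv1 = nn.Conv2d(
    in_channels=in_chans,  # use the in_chans argument instead of a hard-coded 3
    out_channels=width,
    kernel_size=patch_size,
    stride=patch_size,
    bias=False,
)
```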
```python
class Rope2D(Module):
    def __init__(self, dim, grid_size, use_cls_token=False):
```
This module should be marked non-traceable to pass the FX tests, as it looks like the `if t.ndim == 3` will break tracing. See e.g. pytorch-image-models/timm/models/xcit.py, line 33 in c8c4f25:

```python
@register_notrace_module  # reason: FX can't symbolically trace torch.arange in forward method
```
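Applied here, that would look roughly like the following (a sketch mirroring the xcit.py pattern; the relative import matches timm's internal layout):

```python
# Inside timm/models/pe.py:
import torch.nn as nn

from ._features_fx import register_notrace_module


@register_notrace_module  # reason: FX can't symbolically trace the `if t.ndim == 3` branch
class Rope2D(nn.Module):
    def __init__(self, dim, grid_size, use_cls_token=False):
        super().__init__()
        ...
```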
```python
freq = torch.cat([freq, torch.zeros(1, freq.shape[-1])], dim=0)

self.freq = Parameter(freq[None, ...])  # remark: using Parameter instead of tensor for device consistency
```
Also a complaint about this parameter; was it originally not a parameter, since it doesn't exist in the released state dicts?
Changed rope freq to nn.Buffer(freqs, persistent=False). Thanks for the suggestion.
…s cfg/fxforward/fxbackward unitest
…ffer, add drop_rate arg
@berniebear sorry, silly typo in my comments that wasn't in my working hacks, it's self.register_buffer not nn.Buffer, haha ...
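Putting the correction together, a minimal sketch of the intended pattern (the class name is illustrative):

```python
import torch
import torch.nn as nn


class RotaryEmbeddingSketch(nn.Module):
    """Illustrative sketch of the corrected freqs handling."""

    def __init__(self, freqs: torch.Tensor, learned_freq: bool = False):
        super().__init__()
        if learned_freq:
            # Learnable: a real Parameter, saved in the state dict.
            self.freqs = nn.Parameter(freqs)
        else:
            # Fixed: register_buffer rather than nn.Buffer (which only exists in
            # recent PyTorch); persistent=False keeps it out of the state dict,
            # matching the released pretrained weights.
            self.register_buffer('freqs', freqs, persistent=False)
```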
Add Perception Encoder to timm.
Intro
This PR aims to integrate the Perception Encoder (paper, code) from FAIR into timm. We thank you for the support and feedback.
Perception Encoder Performance (benchmark categories):
- Vision-Language Benchmarks
- Multimodal LLM Benchmarks
- Vision-centric Benchmarks: Linear Probe (448px w/o TTA), Mask R-CNN 1024px (Box / Mask mAP), DETA 1824px (Box mAP)
Proposed integration and changes:
Known issues/limitations:
PE models available (hf_hub paths; see the usage sketch below):
A. ViT only
B. CLIP (ViT + text transformer, for future open_clip integration only)
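A usage sketch once merged (the model name below is hypothetical, for illustration only; substitute an actual hf_hub path from the supported-model list):

```python
import timm
import torch

# Hypothetical model name; see this PR's model list for the real ones.
model = timm.create_model('vit_pe_core_large_patch14_336', pretrained=True, num_classes=0)
model.eval()

x = torch.randn(1, 3, 336, 336)
features = model(x)  # pooled features; with num_classes=0 no classifier head is applied
```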
Test plan (parity):
All the models supported and tested:
Note:
Thanks for all the support and feedback for this timm integration!